from torchvision import models
from utils import no_update
from torch import nn


def create_model(num_classes: int = 10) -> nn.Module:
    """Build a ResNet-152 backbone with a small MLP classification head.

    The backbone is created with ``pretrained=True`` and passed to
    ``no_update`` *before* its final ``fc`` layer is replaced, so whatever
    ``no_update`` does to the backbone does not affect the new head.
    (Presumably ``no_update`` sets ``requires_grad=False`` on the backbone
    parameters, leaving only the new head trainable -- confirm in ``utils``.)

    Args:
        num_classes: Number of output logits of the final linear layer.
            Defaults to 10, matching the previously hard-coded head.

    Returns:
        The ResNet-152 model whose ``fc`` attribute is
        ``Linear(in_features, 1024) -> ReLU -> Linear(1024, 256) -> ReLU
        -> Linear(256, num_classes)``, where ``in_features`` is the input
        width of the backbone's original ``fc`` layer.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Validate before loading weights so a bad argument fails fast.
    if num_classes < 1:
        raise ValueError(f"num_classes must be positive, got {num_classes}")

    # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
    # favour of `weights=models.ResNet152_Weights.DEFAULT`; kept unchanged here
    # so the code still works with older torchvision versions.
    feature_extract = models.resnet152(pretrained=True)
    no_update(feature_extract)

    # Read the head's input width from the existing layer instead of
    # hard-coding 2048, so the head stays correct if the backbone changes.
    in_features = feature_extract.fc.in_features
    feature_extract.fc = nn.Sequential(
        nn.Linear(in_features, 1024),
        nn.ReLU(),
        nn.Linear(1024, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes),
    )
    return feature_extract